# Copyright (c) OpenMMLab. All rights reserved.
from torch.utils.data import ConcatDataset as TorchConcatDataset

from xtuner.registry import BUILDER


class ConcatDataset(TorchConcatDataset):
    """Concatenate datasets that are first built from config entries.

    Each element of ``datasets`` is passed to ``BUILDER.build`` in order,
    and the resulting objects are handed to
    :class:`torch.utils.data.ConcatDataset` as its ``datasets`` argument.

    Args:
        datasets: Iterable of config entries accepted by ``BUILDER.build``.
    """

    def __init__(self, datasets):
        # Build every entry eagerly, preserving input order, before
        # delegating indexing/length bookkeeping to the torch base class.
        built = [BUILDER.build(dataset_cfg) for dataset_cfg in datasets]
        super().__init__(datasets=built)

    def __repr__(self):
        """Return a header line followed by each child's repr, comma-separated."""
        header = "Dataset as a concatenation of multiple datasets. \n"
        body = ",\n".join(map(repr, self.datasets))
        return header + body
